# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
import numpy as np

def MVT_arms_identif(dim, var):
    """Build the identifiable design matrix for a full-factorial MVT.

    Every combination of ``dim`` factors, each with ``var`` levels, becomes
    one row (arm). Arms are enumerated in lexicographic order: the first
    factor varies slowest and the last factor varies fastest.

    Each row contains, left to right:

    * an intercept column of ones;
    * ``dim * (var - 1)`` main-effect indicators, i.e. a one-hot encoding of
      each factor's level with the first level (level 0) dropped as the
      baseline so the columns are not collinear with the intercept;
    * ``(var - 1)**2 * dim * (dim - 1) // 2`` pairwise interaction
      indicators, the products of main-effect indicators of two different
      factors. They are ordered by first factor ``a``, then the level of
      ``a``, then every main-effect column belonging to factors after ``a``.

    Parameters
    ----------
    dim : int
        Number of factors (dimensions); must be >= 1.
    var : int
        Number of levels (variations) per factor; must be >= 1.

    Returns
    -------
    numpy.ndarray
        Float array of shape
        ``(var**dim, 1 + dim*(var-1) + (var-1)**2 * dim*(dim-1)//2)``.

    Raises
    ------
    ValueError
        If ``dim`` or ``var`` is smaller than 1.
    """
    if dim < 1 or var < 1:
        raise ValueError(
            f"dim and var must both be >= 1, got dim={dim}, var={var}"
        )

    # Full one-hot encoding of every arm. Each additional factor is appended
    # so that its levels cycle fastest: every existing row is repeated
    # ``var`` times and paired with a vertically tiled identity matrix.
    # (np.tile replaces the deprecated np.row_stack([main] * n).)
    main = np.eye(var)
    main_matrix = main
    for _ in range(1, dim):
        n_rows = main_matrix.shape[0]
        main_matrix = np.c_[
            np.repeat(main_matrix, var, axis=0),
            np.tile(main, (n_rows, 1)),
        ]

    # Drop the first level of every factor (the baseline category).
    main_matrix = np.delete(main_matrix, np.arange(0, dim * var, var), axis=1)
    levels = var - 1
    n_arms, n_main = main_matrix.shape

    # Pairwise interactions, vectorized over all arms at once instead of
    # appending one row at a time (which copied the whole matrix per row).
    # For factor ``a``, multiply its indicator block against all columns of
    # later factors; the row-major reshape keeps the ordering
    # (a, level of a, later column) used by the original per-row outer product.
    blocks = []
    for a in range(dim - 1):
        start, stop = a * levels, (a + 1) * levels
        block = main_matrix[:, start:stop, None] * main_matrix[:, None, stop:]
        blocks.append(block.reshape(n_arms, levels * (n_main - stop)))

    intercept = np.ones((n_arms, 1))
    final_factor_matrix = np.hstack([intercept, main_matrix, *blocks])

    return final_factor_matrix